from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from IPython.display import HTML
from matplotlib.animation import ArtistAnimation
from tqdm import tqdm
from diffdrr import DRR, load_example_ct
from diffdrr.metrics import XCorr2
from diffdrr.visualization import animate, plot_drr
np.random.seed(39)


def converged(df, threshold=-0.999):
    """Return True if the last recorded loss in ``df`` is below ``threshold``.

    The comparison is strict (``<``), matching the early-stopping test used by
    the optimization loop in this notebook.

    Args:
        df: Trajectory DataFrame with a ``"loss"`` column, one row per iteration.
        threshold: Loss value the final iteration must fall below.

    Returns:
        False for an empty trajectory; otherwise whether the final loss is
        strictly below ``threshold``.

    Raises:
        KeyError: If ``df`` has no ``"loss"`` column.
    """
    if "loss" not in df.columns:
        raise KeyError("converged() needs a 'loss' column in the trajectory DataFrame")
    if len(df) == 0:
        # No iterations were recorded, so nothing can have converged.
        return False
    return df["loss"].iloc[-1] < threshold


# Make the ground truth X-ray
SDR = 200.0
HEIGHT = 100
DELX = 5e-2
# Use the GPU when available, but fall back to CPU so the example still runs
# on machines without CUDA (a hard-coded "cuda" would raise there).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

volume, spacing = load_example_ct()
# Half of (voxel count * spacing) along each axis, i.e. the volume's center.
# Assumes ``spacing`` is the per-axis voxel size in the same axis order as
# ``volume.shape`` -- TODO confirm against load_example_ct.
bx, by, bz = np.array(volume.shape) * np.array(spacing) / 2
true_params = {
    "sdr": SDR,
    "theta": torch.pi,
    "phi": 0,
    "gamma": torch.pi / 2,
    "bx": bx,
    "by": by,
    "bz": bz,
}

# Render the target image that the optimizers will try to match.
drr = DRR(volume, spacing, height=HEIGHT, delx=DELX, device=DEVICE)
ground_truth = drr(**true_params)
plot_drr(ground_truth)
plt.show()
# Make a random DRR
def get_initial_parameters(
    true_params,
    *,
    rotation_half_widths=(np.pi / 4, np.pi / 3, np.pi / 3),
    translation_bounds=(-30.0, 31.0),
):
    """Perturb a ground-truth pose with uniform noise to get an optimizer start.

    Samples come from NumPy's global RNG, drawn in the fixed order theta, phi,
    gamma, bx, by, bz. With the default bounds, a given seed reproduces exactly
    the same starting pose as before these bounds were made configurable.

    Args:
        true_params: Mapping with keys "sdr", "theta", "phi", "gamma", "bx",
            "by", "bz".
        rotation_half_widths: ``(w_theta, w_phi, w_gamma)``; each angle is
            offset by a draw from ``U(-w, w)``.
        translation_bounds: ``(low, high)`` bounds of the uniform offset added
            independently to each of bx, by and bz. NOTE(review): the default
            (-30.0, 31.0) is asymmetric, probably a typo for +/-30, but it is
            kept so seeded runs reproduce previously recorded results.

    Returns:
        Tuple ``(sdr, theta, phi, gamma, bx, by, bz)``; ``sdr`` is passed
        through unperturbed.
    """
    w_theta, w_phi, w_gamma = rotation_half_widths
    low, high = translation_bounds

    sdr = true_params["sdr"]
    theta = true_params["theta"] + np.random.uniform(-w_theta, w_theta)
    phi = true_params["phi"] + np.random.uniform(-w_phi, w_phi)
    gamma = true_params["gamma"] + np.random.uniform(-w_gamma, w_gamma)
    bx = true_params["bx"] + np.random.uniform(low, high)
    by = true_params["by"] + np.random.uniform(low, high)
    bz = true_params["bz"] + np.random.uniform(low, high)
    return sdr, theta, phi, gamma, bx, by, bz
# Draw one random starting pose. Its components stay bound as module-level
# names because later cells re-apply the same pose before each optimizer run.
initial_pose = get_initial_parameters(true_params)
sdr, theta, phi, gamma, bx, by, bz = initial_pose

# Initialize the DRR generator at that pose and show the resulting image.
est = drr(*initial_pose)
plot_drr(est)
plt.show()
def optimize(
    drr,
    ground_truth,
    lr_rotations=5.3e-2,
    lr_translations=7.5e1,
    momentum=0,
    dampening=0,
    n_itrs=250,
    threshold=-0.999,
):
    """Fit the DRR pose to ``ground_truth`` with SGD and return the trajectory.

    Each iteration renders the current pose, scores it with the negated
    ``XCorr2(zero_mean_normalized=True)`` similarity, and takes one SGD step.
    A perfect match is presumably -1, which is why the default ``threshold`` is
    -0.999. Rotations and translations get separate learning rates via two
    parameter groups.

    Args:
        drr: Renderer exposing ``rotations`` and ``translations`` tensors, which
            are optimized in place. Calling it with no arguments renders the
            current pose.
        ground_truth: Target image compared against each rendering.
        lr_rotations: Learning rate for ``drr.rotations``.
        lr_translations: Learning rate for ``drr.translations``.
        momentum: SGD momentum factor.
        dampening: SGD dampening for momentum.
        n_itrs: Maximum number of iterations.
        threshold: Stop early once the loss is strictly below this value.

    Returns:
        DataFrame with one row per iteration. Columns theta, phi, gamma, bx, by
        and bz hold the pose *before* that iteration's update; column loss holds
        the loss evaluated at that pose.
    """
    criterion = XCorr2(zero_mean_normalized=True)
    optimizer = torch.optim.SGD(
        [
            {"params": [drr.rotations], "lr": lr_rotations},
            {"params": [drr.translations], "lr": lr_translations},
        ],
        momentum=momentum,
        dampening=dampening,
    )

    params = []
    for itr in tqdm(range(n_itrs)):
        estimate = drr()
        # Minimizing the negated similarity maximizes the image match.
        loss = -criterion(ground_truth, estimate)
        loss_value = loss.item()

        theta, phi, gamma = drr.rotations.squeeze()
        bx, by, bz = drr.translations.squeeze()
        params.append([v.item() for v in (theta, phi, gamma, bx, by, bz)] + [loss_value])

        optimizer.zero_grad()
        # retain_graph=True kept from the original; presumably part of the DRR
        # graph is reused across iterations -- TODO confirm, dropping it would
        # save memory if it is not needed.
        loss.backward(retain_graph=True)
        optimizer.step()

        if loss_value < threshold:
            # itr is 0-based, so itr + 1 iterations (and recorded rows) have run.
            tqdm.write(f"Converged in {itr + 1} iterations")
            break

    return pd.DataFrame(
        params, columns=["theta", "phi", "gamma", "bx", "by", "bz", "loss"]
    )


# Base SGD
# Before each run, re-apply the shared random starting pose so every optimizer
# starts from the same point. (Presumably calling drr with explicit arguments
# overwrites drr.rotations / drr.translations, which optimize() then reads.)
drr(sdr, theta, phi, gamma, bx, by, bz)
params_base = optimize(drr, ground_truth)

# SGD + momentum
drr(sdr, theta, phi, gamma, bx, by, bz)
params_momentum = optimize(drr, ground_truth, momentum=0.9)

# SGD + momentum + dampening
drr(sdr, theta, phi, gamma, bx, by, bz)
params_momentum_dampen = optimize(drr, ground_truth, momentum=0.9, dampening=0.1)

# Output recorded when this cell was last executed (previously pasted in as
# bare text, which made the script a syntax error):
#    40%| 101/250 [00:03<00:05, 26.04it/s]
#   Converged in 101 iterations
#    27%| 68/250 [00:02<00:07, 25.83it/s]
#   Converged in 68 iterations
#    21%| 53/250 [00:02<00:07, 25.78it/s]
#   Converged in 53 iterations
def precompute_drrs(df, sdr, drr, ax, max_len=len(params_base)):
    """Render one animation frame per row of an optimization trajectory.

    Args:
        df: Trajectory DataFrame with theta, phi, gamma, bx, by and bz columns,
            one row per iteration (extra columns are ignored).
        sdr: Value passed as the first positional argument to ``drr``.
        drr: Renderer called as ``drr(sdr, theta, phi, gamma, bx, by, bz)``.
        ax: Matplotlib axis the frames are drawn on.
        max_len: Pad the returned list with copies of the last frame up to this
            length, so several runs can be zipped frame-by-frame. NOTE: the
            default is evaluated once, at definition time, from ``params_base``.

    Returns:
        List of whatever ``plot_drr(..., animated=True)`` returns, one per row,
        padded to ``max_len``. It is never truncated if ``df`` is longer.

    Raises:
        ValueError: If ``df`` has no rows, since there is no frame to pad with.
    """
    if len(df) == 0:
        raise ValueError("precompute_drrs needs at least one trajectory row to render")

    imgs = []
    for pos, (_, row) in enumerate(df.iterrows()):
        pose = row[["theta", "phi", "gamma", "bx", "by", "bz"]].values
        drr_img = drr(sdr, *pose)
        img = plot_drr(drr_img, animated=True, ax=ax)
        if pos == 0:
            # Matplotlib leaves animated artists out of regular draws, so also
            # draw the first pose normally to give the axis a static image.
            # Using the position (not the index label) keeps this working on
            # sliced or reindexed frames.
            plot_drr(drr_img, ax=ax)
        imgs.append(img)

    # Hold the final frame so shorter runs stay visible while longer ones play.
    imgs.extend([img] * (max_len - len(df)))
    return imgs


fig, axs = plt.subplots(ncols=4, dpi=300, figsize=(10, 3), constrained_layout=True)
plot_drr(ground_truth, ax=axs[0])

# Pad every run to the longest trajectory. Relying on precompute_drrs's default
# (len(params_base)) only works if the base run is the longest; otherwise zip()
# below would silently drop the tail of the longer runs.
runs = (params_base, params_momentum, params_momentum_dampen)
n_frames = max(len(run) for run in runs)
imgs1 = precompute_drrs(params_base, SDR, drr, axs[1], max_len=n_frames)
imgs2 = precompute_drrs(params_momentum, SDR, drr, axs[2], max_len=n_frames)
imgs3 = precompute_drrs(params_momentum_dampen, SDR, drr, axs[3], max_len=n_frames)

# Each animation frame is the list of artists shown together across the panels.
imgs = [list(frame) for frame in zip(imgs1, imgs2, imgs3)]
anim = ArtistAnimation(fig, imgs, interval=50, blit=True, repeat_delay=1000)
# Close the figure so the notebook shows only the HTML animation, not an extra
# static copy.
plt.close(fig)
HTML(anim.to_jshtml())